package MgoPool

import (
	"bytes"
	"fmt"
	"log"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	"gopkg.in/mgo.v2"
)

// Logging configuration shared by the log-path helpers in this file.
var (
	// DEFULA_LOG_DIR is the default log directory, appended to the current
	// working directory, used when a builder does not provide its own.
	DEFULA_LOG_DIR = "/Log"
	// UTFALL_DATE is the time layout used for the date suffix of log files.
	UTFALL_DATE = "2006-01-02"
	// cstZone is a fixed UTC+8 zone named "CST" used when dating log files.
	cstZone = time.FixedZone("CST", 8*60*60)
)

// mSession wraps one dialed *mgo.Session held by the pool.
type mSession struct {
	session *mgo.Session
}

// newMSession wraps s in a new mSession.
func newMSession(s *mgo.Session) *mSession {
	wrapped := &mSession{}
	wrapped.session = s
	return wrapped
}

// mBuilder collects the builders registered with a pool.
type mBuilder struct {
	a []MgoBuilderApi
}

// newMBuilder returns an mBuilder whose builder list is empty but non-nil.
func newMBuilder() *mBuilder {
	b := new(mBuilder)
	b.a = make([]MgoBuilderApi, 0)
	return b
}

// setBuilders appends extra to the registered builders. append copies the
// element values into b.a, so later writes to the caller's slice do not
// affect what is stored here.
func (b *mBuilder) setBuilders(extra []MgoBuilderApi) {
	b.a = append(b.a, extra...)
}

// MPool is the package-level pool, created by NewMgoPool when the package is
// initialized. Builders still have to be registered with SetBuilders and
// dialed with Connection before GetBuilder/GetSession can find them.
var MPool = NewMgoPool()

/**
MgoPool is a pool of mgo sessions grouped by builder.

Every builder registered via SetBuilders gets its own slice of dialed
sessions; GetSession and GetSessionOnly hand out copies of those sessions,
choosing among them through the round-robin load balancer.
*/
type MgoPool struct {
	mBuilder   *mBuilder
	sessions   map[string][]*mSession   // pooled sessions, keyed by builder hash name
	builderMap map[string]MgoBuilderApi // registered builders, keyed by builder hash name

	sessionIndex int32 // index of the session most recently handed out (shared by all builders)

	loadBalance *loadBalance

	sessionLock sync.Mutex // held by GetBuilder while reading builderMap
	connecLock  sync.Mutex // NOTE(review): not used anywhere in the code visible here
}

// NewMgoPool returns an empty pool: no builders registered, no sessions dialed,
// and the round-robin index starting at its zero value.
func NewMgoPool() *MgoPool {
	return &MgoPool{
		mBuilder:    newMBuilder(),
		sessions:    map[string][]*mSession{},
		builderMap:  map[string]MgoBuilderApi{},
		loadBalance: newLoadBalance(),
	}
}

// SetBuilders registers builders with the pool. They are dialed by the next
// call to Connection; repeated calls append to the list rather than replace it.
func (this *MgoPool) SetBuilders(builders ...MgoBuilderApi) {
	registry := this.mBuilder
	registry.setBuilders(builders)
}

// Connection dials a session pool for every builder registered through
// SetBuilders. It returns MGO_ERROR_BUILDER_EMPTY when no builder has been
// registered, otherwise whatever createBuilder reports.
func (this *MgoPool) Connection() error {
	if this.mBuilder != nil && len(this.mBuilder.a) > 0 {
		return this.createBuilder()
	}
	return MGO_ERROR_BUILDER_EMPTY
}

// GetSession copies the next pooled session for builderName, switches the copy
// to Strong mode and passes it, together with the builder's database, to s.
// The copy is closed when GetSession returns. A nil s is allowed and does nothing.
//
// Errors from looking up the builder are returned unchanged. If s fails,
// GetSession returns MGO_EXEC_CALLBACK_ERROR; the callback's own error is not
// propagated.
func (this *MgoPool) GetSession(builderName string, s func(session *mgo.Session, db *mgo.Database) error) error {
	pooled, builder, err := this.getMSession(builderName)
	if err != nil {
		return err
	}

	session := pooled.session.Copy()
	defer session.Close()

	if s == nil {
		return nil
	}
	db := session.DB(builder.GetDataBase())
	session.SetMode(mgo.Strong, true)
	if cbErr := s(session, db); cbErr != nil {
		return MGO_EXEC_CALLBACK_ERROR
	}
	return nil
}

// GetSessionOnly returns a copy of the next pooled session for builderName,
// switched to Strong mode, plus a handle to the builder's database.
// Nothing here closes the copy, so the caller should Close it when finished.
func (this *MgoPool) GetSessionOnly(builderName string) (*mgo.Session, *mgo.Database, error) {
	pooled, builder, err := this.getMSession(builderName)
	if err != nil {
		return nil, nil, err
	}
	session := pooled.session.Copy()
	db := session.DB(builder.GetDataBase())
	session.SetMode(mgo.Strong, true)
	return session, db, nil
}

// GetBuilder looks up a registered builder by name, using hash256(builderName)
// as the map key. It returns MGO_ERROR_BUILDER_NAME_EMPTY for a blank name and
// MGO_ERROR_BUILDER_NOT_EXISTS when no builder is registered under the name.
func (this *MgoPool) GetBuilder(builderName string) (MgoBuilderApi, error) {
	if strings.TrimSpace(builderName) == "" {
		return nil, MGO_ERROR_BUILDER_NAME_EMPTY
	}
	key := hash256(builderName)

	this.sessionLock.Lock()
	defer this.sessionLock.Unlock()
	if builder, found := this.builderMap[key]; found {
		return builder, nil
	}
	return nil, MGO_ERROR_BUILDER_NOT_EXISTS
}

// getMSession picks the next pooled session for builderName using the pool's
// round-robin load balancer and returns it together with its builder.
//
// Errors from GetBuilder (blank or unknown name) are returned unchanged;
// MGO_ERROR_BUILDER_NOT_EXISTS is also returned if the builder has no sessions.
func (this *MgoPool) getMSession(builderName string) (*mSession, MgoBuilderApi, error) {
	// GetBuilder takes sessionLock itself, so it must run before we lock below
	// (sync.Mutex is not reentrant).
	builder, err := this.GetBuilder(builderName)
	if err != nil {
		return nil, nil, err
	}
	builderNameHash := hash256(builderName)

	// sessionIndex and the sessions map are shared by every caller; updating
	// them without a lock is a data race.
	this.sessionLock.Lock()
	defer this.sessionLock.Unlock()

	pool := this.sessions[builderNameHash]
	if len(pool) == 0 {
		return nil, nil, MGO_ERROR_BUILDER_NOT_EXISTS
	}

	// Balance over the sessions that were actually created: createSessions
	// falls back to DEFAULT_POOL_SIZE when builder.GetPoolSize() <= 0, so the
	// configured size is not a safe bound.
	size := int32(len(pool))
	index := this.loadBalance.RoundRobin(this.sessionIndex, size)
	// sessionIndex is shared by all builders, so force the result into this
	// builder's range before indexing.
	if index < 0 || index >= size {
		index = ((index % size) + size) % size
	}
	this.sessionIndex = index
	return pool[index], builder, nil
}

// createBuilder dials a session pool for every registered builder and records
// the builder and its sessions under the builder's GetHashName().
//
// It stops at the first failure, returning MGO_ERROR_BUILDER_NAME_EMPTY for a
// builder with a blank name or the error from createSessions. Builders
// processed before the failure remain registered.
func (this *MgoPool) createBuilder() error {
	for _, v := range this.mBuilder.a {
		if len(strings.TrimSpace(v.GetBuilderName())) == 0 {
			return MGO_ERROR_BUILDER_NAME_EMPTY
		}
		sessions, err := this.createSessions(v)
		if err != nil {
			return err
		}
		hashName := v.GetHashName()

		// Publish both maps together under sessionLock, the lock GetBuilder
		// holds while reading builderMap, so a lock-holding reader never sees
		// a builder whose sessions are not stored yet. createSessions returns
		// a fresh slice, so it can be stored without copying.
		this.sessionLock.Lock()
		this.builderMap[hashName] = v
		this.sessions[hashName] = sessions
		this.sessionLock.Unlock()
	}
	return nil
}

// createSessions dials the session pool for builder. It creates
// builder.GetPoolSize() sessions, or DEFAULT_POOL_SIZE when that value is not
// positive. All sessions use the connection string built by getMgoDns.
//
// If any dial fails, the sessions dialed so far are closed and (nil, err) is
// returned, so a partially built pool does not leak open connections.
func (this *MgoPool) createSessions(builder MgoBuilderApi) ([]*mSession, error) {
	var poolSize int32 = DEFAULT_POOL_SIZE
	if builder.GetPoolSize() > 0 {
		poolSize = builder.GetPoolSize()
	}
	dns := this.getMgoDns(builder)

	sessions := make([]*mSession, 0, poolSize)
	for i := int32(0); i < poolSize; i++ {
		item, err := this.getItemSession(dns, builder)
		if err != nil {
			for _, opened := range sessions {
				opened.session.Close()
			}
			return nil, err
		}
		sessions = append(sessions, item)
	}
	return sessions, nil
}

// getItemSession dials one mgo session for dns. It sets the session's pool
// limit to builder.GetPoolSize() and its mode to Strong.
//
// When builder.GetIsDebug() is true, mgo's package-global debug logging is
// enabled and directed to a log file under the builder's log dir
// (DEFULA_LOG_DIR when the builder's dir is blank). A dial failure is reported
// as MGO_ERROR_CONNECTION_ERROR; the underlying mgo error is not propagated.
func (this *MgoPool) getItemSession(dns string, builder MgoBuilderApi) (*mSession, error) {
	if builder.GetIsDebug() {
		logDir := DEFULA_LOG_DIR
		if dir := strings.TrimSpace(builder.GetLogDir()); len(dir) > 0 {
			logDir = dir
		}
		enableMgoDebugLog(logDir)
	}
	session, err := mgo.Dial(dns)
	if err != nil {
		return nil, MGO_ERROR_CONNECTION_ERROR
	}
	session.SetPoolLimit(int(builder.GetPoolSize()))
	session.SetMode(mgo.Strong, true)
	return newMSession(session), nil
}

var (
	// mgoDebugLogMu guards mgoDebugLogPath.
	mgoDebugLogMu sync.Mutex
	// mgoDebugLogPath is the log file currently installed as mgo's logger.
	mgoDebugLogPath string
)

// enableMgoDebugLog turns on mgo debug logging and installs a logger that
// writes to the current log file under logDir (as named by getLogPath). The
// file is opened only when its path differs from the one already installed.
// Without this check, dialing a pool of N sessions would open the file N times
// and swap mgo's global logger N times.
func enableMgoDebugLog(logDir string) {
	mgoDebugLogMu.Lock()
	defer mgoDebugLogMu.Unlock()

	mgo.SetDebug(true)
	path := getLogPath(logDir, "mgo", true)
	if path == mgoDebugLogPath {
		return
	}
	mgo.SetLogger(log.New(logPath(logDir), "\r\n", log.Ldate|log.Ltime))
	mgoDebugLogPath = path
}

// getMgoDns builds the connection string for builder in the form
//
//	mongodb://[user:pwd@]addr[/authDataBase]
//
// Credentials are included only when the user name is non-blank, and the
// auth-database segment only when it is non-blank.
func (this *MgoPool) getMgoDns(builder MgoBuilderApi) string {
	return buildMgoDns(builder.GetUser(), builder.GetPwd(), builder.GetAddr(), builder.GetAuthDataBase())
}

// buildMgoDns assembles a mongodb:// URL from its parts.
//
// user and pwd are query-escaped. mgo.v2's URL parser splits the credentials
// on ':' and '@' (after cutting at '?' and '/') and URL-unescapes them. Without
// escaping, a password containing '@', ':', '/', '?', '%' or '+' is mis-parsed
// or rejected. Plain alphanumeric credentials are unaffected.
func buildMgoDns(user, pwd, addr, authDataBase string) string {
	var dnsBuf bytes.Buffer
	dnsBuf.WriteString("mongodb://")
	if len(strings.TrimSpace(user)) > 0 {
		dnsBuf.WriteString(url.QueryEscape(user))
		dnsBuf.WriteString(":")
		dnsBuf.WriteString(url.QueryEscape(pwd))
		dnsBuf.WriteString("@")
	}
	dnsBuf.WriteString(addr)
	if len(strings.TrimSpace(authDataBase)) > 0 {
		dnsBuf.WriteString("/")
		dnsBuf.WriteString(authDataBase)
	}
	return dnsBuf.String()
}

// logPath opens the mgo log file under logPathDir, using the date-suffixed
// path from getLogPath. The file is created if missing, opened write-only,
// and appended to.
//
// The signature has no error return, so a failure to open the file ends the
// process via log.Fatalln. The OpenFile error is included in the message so
// the cause (path, permission, ...) is visible.
func logPath(logPathDir string) *os.File {
	file, err := os.OpenFile(getLogPath(logPathDir, "mgo", true), os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644)
	if err != nil {
		log.Fatalln("fail to create mgo.log file!", err)
	}
	return file
}

// getLogPath returns the path of a log file laid out as
//
//	<cwd>/<logPathDir>/<logFileName>/<logFileName>[_<date>].log
//
// where <cwd> comes from getCurrentPath. When isDataLogFormat is true, <date>
// is the current time in cstZone formatted with UTFALL_DATE. The containing
// directory is created if it does not exist; getLogPath panics if that fails.
func getLogPath(logPathDir, logFileName string, isDataLogFormat bool) string {
	dateSuffix := ""
	if isDataLogFormat {
		dateSuffix = "_" + time.Now().In(cstZone).Format(UTFALL_DATE)
	}
	sep := string(os.PathSeparator)
	logDir := getCurrentPath() + sep + logPathDir + sep + logFileName
	fullPath := logDir + sep + logFileName + dateSuffix + ".log"

	if _, err := os.Stat(logDir); os.IsNotExist(err) {
		// The mode must be octal. A decimal 777 is mode 0o1411, which gives the
		// owner no write access and makes the later OpenFile of the log fail.
		if errs := os.MkdirAll(logDir, 0755); errs != nil {
			panic(fmt.Sprintf("%s 创建失败", errs))
		}
	}
	return fullPath
}

// getCurrentPath returns the process's current working directory with every
// backslash converted to a forward slash. It exits via log.Fatal if the
// working directory cannot be determined.
func getCurrentPath() string {
	wd, err := os.Getwd()
	if err != nil {
		log.Fatal(err)
	}
	return strings.ReplaceAll(wd, `\`, "/")
}
